{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Compound representation learning and property prediction\n",
    "\n",
    "In this tuorial, we will go through how to run a Graph Neural Network (GNN) model for compound property prediction. In particular, we will demonstrate how to pretrain and finetune the model in the downstream tasks. If you are intersted in more details, please refer to the README for \"[info graph](https://github.com/PaddlePaddle/PaddleHelix/apps/pretrained_compound/info_graph)\" and \"[pretrained GNN](https://github.com/PaddlePaddle/PaddleHelix/apps/pretrained_compound/pretrain_gnns)\".\n",
    "\n",
    "#  Part I: Pretraining\n",
    "\n",
    "In this part, we will show how to pretrain a compound GNN model. The pretraining skills here are adapted from the work of pretrain gnns, including attribute masking, context prediction and supervised pretraining.\n",
    "\n",
    "Visit `pretrain_attrmask.py` and `pretrain_supervised.py` for more details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import sys\n",
    "sys.path.insert(0, os.getcwd() + \"/..\")\n",
    "os.chdir(\"../apps/pretrained_compound/pretrain_gnns\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The Pahelix framework is build upon PaddlePaddle, which is a deep learning framework.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2021-05-08 15:04:44,729 - INFO - ujson not install, fail back to use json instead\n",
      "2021-05-08 15:04:44,809 - INFO - Enabling RDKit 2020.09.1 jupyter extensions\n"
     ]
    }
   ],
   "source": [
    "import paddle\n",
    "import paddle.nn as nn\n",
    "import pgl\n",
    "\n",
    "from pahelix.model_zoo.pretrain_gnns_model import PretrainGNNModel, AttrmaskModel\n",
    "from pahelix.datasets.zinc_dataset import load_zinc_dataset\n",
    "from pahelix.utils.splitters import RandomSplitter\n",
    "from pahelix.featurizers.pretrain_gnn_featurizer import AttrmaskTransformFn, AttrmaskCollateFn\n",
    "from pahelix.utils import load_json_config"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load configurations\n",
    "Here, we use `compound_encoder_config`,`model_config` to hold the compound encoder and model configurations. `PretrainGNNModel` is the basic GNN Model used in pretrain gnns,`AttrmaskModel` is an unsupervised pretraining model which randomly masks the atom type of a node and then tries to predict the masked atom type. In the meanwhile, we use the Adam optimizer and set the learning rate to be 0.001."
   ]
  },
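  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the attribute-masking idea more concrete, here is a minimal NumPy sketch of the masking step on a single toy molecule. It is not how `AttrmaskCollateFn` or `AttrmaskModel` are implemented: the helper `mask_atom_types`, the `mask_token` value of 0 and the fixed `seed` are hypothetical choices made only for illustration.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def mask_atom_types(atom_types, mask_ratio=0.15, mask_token=0, seed=0):\n",
    "    # Hypothetical helper: hide a random subset of atom types behind a mask token.\n",
    "    rng = np.random.default_rng(seed)\n",
    "    atom_types = np.asarray(atom_types)\n",
    "    # Mask at least one atom so that every molecule contributes to the loss.\n",
    "    num_masked = max(1, int(round(len(atom_types) * mask_ratio)))\n",
    "    masked_idx = rng.choice(len(atom_types), size=num_masked, replace=False)\n",
    "    labels = atom_types[masked_idx].copy()  # ground truth the model must recover\n",
    "    masked = atom_types.copy()\n",
    "    masked[masked_idx] = mask_token  # the model only sees the mask token here\n",
    "    return masked, masked_idx, labels\n",
    "\n",
    "\n",
    "atom_types = [6, 6, 8, 7, 6, 6, 16, 6, 8, 6]  # atomic numbers of a toy molecule\n",
    "masked, masked_idx, labels = mask_atom_types(atom_types)\n",
    "print('masked input :', masked)\n",
    "print('masked index :', masked_idx)\n",
    "print('target labels:', labels)\n",
    "```\n",
    "\n",
    "A model trained with this objective receives `masked` as input and is asked to recover `labels` at the positions in `masked_idx`."
   ]
  },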
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/bio/.local/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n",
      "  and should_run_async(code)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[PretrainGNNModel] embed_dim:300\n",
      "[PretrainGNNModel] dropout_rate:0.5\n",
      "[PretrainGNNModel] norm_type:batch_norm\n",
      "[PretrainGNNModel] graph_norm:False\n",
      "[PretrainGNNModel] residual:False\n",
      "[PretrainGNNModel] layer_num:5\n",
      "[PretrainGNNModel] gnn_type:gin\n",
      "[PretrainGNNModel] JK:last\n",
      "[PretrainGNNModel] readout:mean\n",
      "[PretrainGNNModel] atom_names:['atomic_num', 'chiral_tag']\n",
      "[PretrainGNNModel] bond_names:['bond_dir', 'bond_type']\n"
     ]
    }
   ],
   "source": [
    "compound_encoder_config = load_json_config(\"model_configs/pregnn_paper.json\")\n",
    "model_config = load_json_config(\"model_configs/pre_Attrmask.json\")\n",
    "\n",
    "compound_encoder = PretrainGNNModel(compound_encoder_config)\n",
    "model = AttrmaskModel(model_config, compound_encoder)\n",
    "opt = paddle.optimizer.Adam(0.001, parameters=model.parameters())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dataset loading and feature extraction\n",
    "### Download the dataset using wget\n",
    "First, we need to download a small dataset for this demo. If you do not have `wget` on your machine, you could also\n",
    "copy the url below into your web browser to download the data. But remember to copy the data manually to the\n",
    "path \"../apps/pretrained_compound/pretrain_gnns/\"."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2021-05-08 15:05:21--  https://baidu-nlp.bj.bcebos.com/PaddleHelix%2Fdatasets%2Fcompound_datasets%2Fchem_dataset_small.tgz\n",
      "Resolving baidu-nlp.bj.bcebos.com (baidu-nlp.bj.bcebos.com)... 10.70.0.165\n",
      "Connecting to baidu-nlp.bj.bcebos.com (baidu-nlp.bj.bcebos.com)|10.70.0.165|:443... connected.\n",
      "HTTP request sent, awaiting response... 200 OK\n",
      "Length: 609563 (595K) [application/gzip]\n",
      "Saving to: ‘PaddleHelix%2Fdatasets%2Fcompound_datasets%2Fchem_dataset_small.tgz.1’\n",
      "\n",
      "100%[======================================>] 609,563     --.-K/s   in 0.04s   \n",
      "\n",
      "2021-05-08 15:05:21 (13.9 MB/s) - ‘PaddleHelix%2Fdatasets%2Fcompound_datasets%2Fchem_dataset_small.tgz.1’ saved [609563/609563]\n",
      "\n",
      "tox21  zinc_standard_agent\n"
     ]
    }
   ],
   "source": [
    "### Download a toy dataset for demonstration:\n",
    "!wget \"https://baidu-nlp.bj.bcebos.com/PaddleHelix%2Fdatasets%2Fcompound_datasets%2Fchem_dataset_small.tgz\" --no-check-certificate\n",
    "!tar -zxf \"PaddleHelix%2Fdatasets%2Fcompound_datasets%2Fchem_dataset_small.tgz\"\n",
    "!ls \"./chem_dataset_small\"\n",
    "### Download the full dataset as you want:\n",
    "# !wget \"http://snap.stanford.edu/gnn-pretrain/data/chem_dataset.zip\" --no-check-certificate\n",
    "# !unzip \"chem_dataset.zip\"\n",
    "# !ls \"./chem_dataset\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### Load the dataset  and generate features\n",
    "The Zinc dataset is used as the pretraining dataset.Here we use a toy dataset for demonstration,you can load the full dataset as you want.\n",
    "`AttrmaskTransformFn` is used along with `AttrmaskModel`.It is used to generate features. The original features are processed into features available on the network, such as smiles strings into node and edge features."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dataset num: 1000\n"
     ]
    }
   ],
   "source": [
    "### Load the first 1000 of the toy dataset for speed up\n",
    "dataset = load_zinc_dataset(\"./chem_dataset_small/zinc_standard_agent/\")\n",
    "dataset = dataset[:1000]\n",
    "print(\"dataset num: %s\" % (len(dataset)))\n",
    "\n",
    "transform_fn = AttrmaskTransformFn()\n",
    "dataset.transform(transform_fn, num_workers=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Start train\n",
    "Now we train the attrmask model for 2 epochs for demostration purposes. Here we use `AttrmaskTransformFn` to aggregate multiple samples into a mini-batch.And the data loading process is accelerated with 4 processors.Then the pretrained model is saved to \"./model/pretrain_attrmask\", which will serve as the initial model of the downstream tasks."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/bio/tools/paddle2.0/lib/python3.7/site-packages/paddle/nn/layer/norm.py:648: UserWarning: When training, we now always track global mean and variance.\n",
      "  \"When training, we now always track global mean and variance.\")\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch:0 train/loss:2.7467122\n",
      "epoch:1 train/loss:0.93576944\n"
     ]
    }
   ],
   "source": [
    "def train(model, dataset, collate_fn, opt):\n",
    "    data_gen = dataset.get_data_loader(\n",
    "            batch_size=128, \n",
    "            num_workers=4, \n",
    "            shuffle=True,\n",
    "            collate_fn=collate_fn)\n",
    "    list_loss = []\n",
    "    model.train()\n",
    "    for graphs, masked_node_indice, masked_node_label in data_gen:\n",
    "        graphs = graphs.tensor()\n",
    "        masked_node_indice = paddle.to_tensor(masked_node_indice, 'int64')\n",
    "        masked_node_label = paddle.to_tensor(masked_node_label, 'int64')\n",
    "        loss = model(graphs, masked_node_indice, masked_node_label)\n",
    "        loss.backward()\n",
    "        opt.step()\n",
    "        opt.clear_grad()\n",
    "        list_loss.append(loss.numpy())\n",
    "    return np.mean(list_loss)\n",
    "\n",
    "collate_fn = AttrmaskCollateFn(\n",
    "        atom_names=compound_encoder_config['atom_names'], \n",
    "        bond_names=compound_encoder_config['bond_names'],\n",
    "        mask_ratio=0.15)\n",
    "\n",
    "for epoch_id in range(2):\n",
    "    train_loss = train(model, dataset, collate_fn, opt)\n",
    "    print(\"epoch:%d train/loss:%s\" % (epoch_id, train_loss))\n",
    "paddle.save(compound_encoder.state_dict(), \n",
    "        './model/pretrain_attrmask/compound_encoder.pdparams')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The above is about the pretraining steps,you can adjust as needed."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Part II: Downstream finetuning\n",
    "Below we will introduce how to use the pretrained model for the finetuning of downstream tasks.\n",
    "\n",
    "Visit `finetune.py` for more details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pahelix.utils.splitters import \\\n",
    "    RandomSplitter, IndexSplitter, ScaffoldSplitter\n",
    "from pahelix.datasets import *\n",
    "\n",
    "from src.model import DownstreamModel\n",
    "from src.featurizer import DownstreamTransformFn, DownstreamCollateFn\n",
    "from src.utils import calc_rocauc_score, exempt_parameters"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The downstream datasets are usually small and have different tasks. For example, the BBBP dataset is used for the predictions of the Blood-brain barrier permeability. The Tox21 dataset is used for the predictions of toxicity of compounds. Here we use the Tox21 dataset for demonstrations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['NR-AR', 'NR-AR-LBD', 'NR-AhR', 'NR-Aromatase', 'NR-ER', 'NR-ER-LBD', 'NR-PPAR-gamma', 'SR-ARE', 'SR-ATAD5', 'SR-HSE', 'SR-MMP', 'SR-p53']\n"
     ]
    }
   ],
   "source": [
    "task_names = get_default_tox21_task_names()\n",
    "print(task_names)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load configurations\n",
    "Here, we use `compound_encoder_config` and `model_config` to hold the model configurations. Note that the configurations of the model architecture should align with that of the pretraining model, otherwise the loading will fail. \n",
    "`DownstreamModel` is an supervised GNN model which predicts the tasks shown in `task_names`. Meanwhile, we use BCEloss as the criterion,Adam optimizer and set the lr to be 0.001."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[PretrainGNNModel] embed_dim:300\n",
      "[PretrainGNNModel] dropout_rate:0.5\n",
      "[PretrainGNNModel] norm_type:batch_norm\n",
      "[PretrainGNNModel] graph_norm:False\n",
      "[PretrainGNNModel] residual:False\n",
      "[PretrainGNNModel] layer_num:5\n",
      "[PretrainGNNModel] gnn_type:gin\n",
      "[PretrainGNNModel] JK:last\n",
      "[PretrainGNNModel] readout:mean\n",
      "[PretrainGNNModel] atom_names:['atomic_num', 'chiral_tag']\n",
      "[PretrainGNNModel] bond_names:['bond_dir', 'bond_type']\n"
     ]
    }
   ],
   "source": [
    "compound_encoder_config = load_json_config(\"model_configs/pregnn_paper.json\")\n",
    "model_config = load_json_config(\"model_configs/down_linear.json\")\n",
    "model_config['num_tasks'] = len(task_names)\n",
    "\n",
    "compound_encoder = PretrainGNNModel(compound_encoder_config)\n",
    "model = DownstreamModel(model_config, compound_encoder)\n",
    "criterion = nn.BCELoss(reduction='none')\n",
    "opt = paddle.optimizer.Adam(0.001, parameters=model.parameters())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load pretrained models\n",
    "Load the pretrained model in the pretraining phase. Here we load the model \"pretrain_attrmask\" as an example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "compound_encoder.set_state_dict(paddle.load('./model/pretrain_attrmask/compound_encoder.pdparams'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dataset loading and feature extraction\n",
    "`DownstreamTransformFn` is used along with `DownstreamModel`.It is used to generate features. The original features are processed into features available on the network, such as smiles strings into node and edge features. \n",
    "\n",
    "The Tox21 dataset is used as the downstream dataset and we use `ScaffoldSplitter` to split the dataset into train/valid/test set. `ScaffoldSplitter` will firstly order the compounds according to Bemis-Murcko scaffold, then take the first `frac_train` proportion as the train set, the next `frac_valid` proportion as the valid set and the rest as the test set. `ScaffoldSplitter` can better evaluate the generalization ability of the model on out-of-distribution samples. Note that other splitters like `RandomSplitter`, `RandomScaffoldSplitter` and `IndexSplitter` is also available."
   ]
  },
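  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The scaffold-based splitting described above can be sketched in plain Python. This is one common variant (group by scaffold, place the largest groups first), shown under the assumption that the scaffold strings have already been computed; it is not necessarily how `ScaffoldSplitter` is implemented. The function `scaffold_split` and the toy scaffold strings are hypothetical.\n",
    "\n",
    "```python\n",
    "from collections import defaultdict\n",
    "\n",
    "\n",
    "def scaffold_split(scaffolds, frac_train=0.8, frac_valid=0.1):\n",
    "    # Hypothetical helper: split sample indices so that a scaffold never spans two sets.\n",
    "    groups = defaultdict(list)\n",
    "    for idx, scaffold in enumerate(scaffolds):\n",
    "        groups[scaffold].append(idx)\n",
    "    # Largest scaffold groups first; ties are broken by the first index for determinism.\n",
    "    ordered = sorted(groups.values(), key=lambda g: (-len(g), g[0]))\n",
    "    n = len(scaffolds)\n",
    "    train_cutoff = frac_train * n\n",
    "    valid_cutoff = (frac_train + frac_valid) * n\n",
    "    train, valid, test = [], [], []\n",
    "    for group in ordered:\n",
    "        if len(train) + len(group) <= train_cutoff:\n",
    "            train.extend(group)\n",
    "        elif len(train) + len(valid) + len(group) <= valid_cutoff:\n",
    "            valid.extend(group)\n",
    "        else:\n",
    "            test.extend(group)\n",
    "    return train, valid, test\n",
    "\n",
    "\n",
    "# Toy scaffold strings standing in for Bemis-Murcko scaffolds of 10 compounds.\n",
    "scaffolds = ['c1ccccc1'] * 6 + ['c1ccncc1'] * 2 + ['C1CCCCC1', 'c1ccoc1']\n",
    "train, valid, test = scaffold_split(scaffolds)\n",
    "print(len(train), len(valid), len(test))\n",
    "```\n",
    "\n",
    "In this sketch, rare scaffolds tend to land in the valid and test sets, which is why a scaffold split is usually harder than a random split."
   ]
  },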
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "RDKit WARNING: [15:07:52] WARNING: not removing hydrogen atom without neighbors\n",
      "RDKit WARNING: [15:08:03] WARNING: not removing hydrogen atom without neighbors\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train/Valid/Test num: 6264/783/784\n"
     ]
    }
   ],
   "source": [
    "### Load the toy dataset:\n",
    "dataset = load_tox21_dataset(\"./chem_dataset_small/tox21\", task_names)\n",
    "### Load the full dataset:\n",
    "# dataset = load_tox21_dataset(\"./chem_dataset/tox21\", task_names)\n",
    "dataset.transform(DownstreamTransformFn(), num_workers=4)\n",
    "\n",
    "# splitter = RandomSplitter()\n",
    "splitter = ScaffoldSplitter()\n",
    "train_dataset, valid_dataset, test_dataset = splitter.split(\n",
    "        dataset, frac_train=0.8, frac_valid=0.1, frac_test=0.1)\n",
    "print(\"Train/Valid/Test num: %s/%s/%s\" % (\n",
    "        len(train_dataset), len(valid_dataset), len(test_dataset)))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Start train\n",
    "Now we train the attrmask model for 4 epochs for demostration purposes.Here we use `DownstreamCollateFn` to aggregate multiple samples data into a mini-batch. Since each downstream task will contain more than one sub-task, the performance of the model is evaluated by the average roc-auc on all sub-tasks."
   ]
  },
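  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an illustration of the evaluation metric, the following NumPy sketch averages the ROC-AUC over sub-tasks while ignoring entries whose label is missing. It is a simplified stand-in and not the implementation of `calc_rocauc_score`; the helpers `binary_auc` and `mean_auc` and the toy arrays are hypothetical.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def binary_auc(labels, scores):\n",
    "    # ROC-AUC as the fraction of (positive, negative) pairs ranked correctly.\n",
    "    pos = scores[labels == 1]\n",
    "    neg = scores[labels == 0]\n",
    "    diff = pos[:, None] - neg[None, :]\n",
    "    return float(np.mean((diff > 0) + 0.5 * (diff == 0)))  # ties count as half\n",
    "\n",
    "\n",
    "def mean_auc(labels, scores, valids):\n",
    "    # Average the per-task AUC, using only entries marked as valid.\n",
    "    aucs = []\n",
    "    for t in range(labels.shape[1]):\n",
    "        keep = valids[:, t] > 0.5\n",
    "        y, s = labels[keep, t], scores[keep, t]\n",
    "        # AUC is undefined unless both classes are present, so skip such tasks.\n",
    "        if (y == 1).any() and (y == 0).any():\n",
    "            aucs.append(binary_auc(y, s))\n",
    "    return float(np.mean(aucs)), len(aucs)\n",
    "\n",
    "\n",
    "# Toy data: 4 compounds, 2 sub-tasks; one label of the second task is missing.\n",
    "labels = np.array([[1, 0], [0, 0], [1, 1], [0, 1]], dtype='float32')\n",
    "scores = np.array([[0.9, 0.2], [0.1, 0.6], [0.8, 0.7], [0.3, 0.1]], dtype='float32')\n",
    "valids = np.array([[1, 1], [1, 0], [1, 1], [1, 1]], dtype='float32')\n",
    "avg_auc, num_tasks = mean_auc(labels, scores, valids)\n",
    "print('average AUC: %s over %s tasks' % (avg_auc, num_tasks))\n",
    "```\n",
    "\n",
    "The same validity mask appears in the training loop, where the per-entry loss is multiplied by `valids` before averaging."
   ]
  },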
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Valid ratio: 0.7603235\n",
      "Task evaluated: 12/12\n",
      "Valid ratio: 0.7513818\n",
      "Task evaluated: 12/12\n",
      "epoch:0 train/loss:0.2782306\n",
      "epoch:0 val/auc:0.6327719308701215\n",
      "epoch:0 test/auc:0.6230107310357754\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/bio/tools/paddle2.0/lib/python3.7/site-packages/paddle/nn/layer/norm.py:648: UserWarning: When training, we now always track global mean and variance.\n",
      "  \"When training, we now always track global mean and variance.\")\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Valid ratio: 0.7603235\n",
      "Task evaluated: 12/12\n",
      "Valid ratio: 0.7513818\n",
      "Task evaluated: 12/12\n",
      "epoch:1 train/loss:0.22611414\n",
      "epoch:1 val/auc:0.6998613183307891\n",
      "epoch:1 test/auc:0.661867930217498\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/bio/tools/paddle2.0/lib/python3.7/site-packages/paddle/nn/layer/norm.py:648: UserWarning: When training, we now always track global mean and variance.\n",
      "  \"When training, we now always track global mean and variance.\")\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Valid ratio: 0.7603235\n",
      "Task evaluated: 12/12\n",
      "Valid ratio: 0.7513818\n",
      "Task evaluated: 12/12\n",
      "epoch:2 train/loss:0.22136745\n",
      "epoch:2 val/auc:0.6257492122355784\n",
      "epoch:2 test/auc:0.6332904016792812\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/bio/tools/paddle2.0/lib/python3.7/site-packages/paddle/nn/layer/norm.py:648: UserWarning: When training, we now always track global mean and variance.\n",
      "  \"When training, we now always track global mean and variance.\")\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Valid ratio: 0.7603235\n",
      "Task evaluated: 12/12\n",
      "Valid ratio: 0.7513818\n",
      "Task evaluated: 12/12\n",
      "epoch:3 train/loss:0.21559708\n",
      "epoch:3 val/auc:0.6922950038299994\n",
      "epoch:3 test/auc:0.6663004542741277\n"
     ]
    }
   ],
   "source": [
    "def train(model, train_dataset, collate_fn, criterion, opt):\n",
    "    data_gen = train_dataset.get_data_loader(\n",
    "            batch_size=128, \n",
    "            num_workers=4, \n",
    "            shuffle=True,\n",
    "            collate_fn=collate_fn)\n",
    "    list_loss = []\n",
    "    model.train()\n",
    "    for graphs, valids, labels in data_gen:\n",
    "        graphs = graphs.tensor()\n",
    "        labels = paddle.to_tensor(labels, 'float32')\n",
    "        valids = paddle.to_tensor(valids, 'float32')\n",
    "        preds = model(graphs)\n",
    "        loss = criterion(preds, labels)\n",
    "        loss = paddle.sum(loss * valids) / paddle.sum(valids)\n",
    "        loss.backward()\n",
    "        opt.step()\n",
    "        opt.clear_grad()\n",
    "        list_loss.append(loss.numpy())\n",
    "    return np.mean(list_loss)\n",
    "\n",
    "def evaluate(model, test_dataset, collate_fn):\n",
    "    data_gen = test_dataset.get_data_loader(\n",
    "            batch_size=128, \n",
    "            num_workers=4, \n",
    "            shuffle=False,\n",
    "            collate_fn=collate_fn)\n",
    "    total_pred = []\n",
    "    total_label = []\n",
    "    total_valid = []\n",
    "    model.eval()\n",
    "    for graphs, valids, labels in data_gen:\n",
    "        graphs = graphs.tensor()\n",
    "        labels = paddle.to_tensor(labels, 'float32')\n",
    "        valids = paddle.to_tensor(valids, 'float32')\n",
    "        preds = model(graphs)\n",
    "        total_pred.append(preds.numpy())\n",
    "        total_valid.append(valids.numpy())\n",
    "        total_label.append(labels.numpy())\n",
    "    total_pred = np.concatenate(total_pred, 0)\n",
    "    total_label = np.concatenate(total_label, 0)\n",
    "    total_valid = np.concatenate(total_valid, 0)\n",
    "    return calc_rocauc_score(total_label, total_pred, total_valid)\n",
    "\n",
    "collate_fn = DownstreamCollateFn(\n",
    "        atom_names=compound_encoder_config['atom_names'], \n",
    "        bond_names=compound_encoder_config['bond_names'])\n",
    "for epoch_id in range(4):\n",
    "    train_loss = train(model, train_dataset, collate_fn, criterion, opt)\n",
    "    val_auc = evaluate(model, valid_dataset, collate_fn)\n",
    "    test_auc = evaluate(model, test_dataset, collate_fn)\n",
    "    print(\"epoch:%s train/loss:%s\" % (epoch_id, train_loss))\n",
    "    print(\"epoch:%s val/auc:%s\" % (epoch_id, val_auc))\n",
    "    print(\"epoch:%s test/auc:%s\" % (epoch_id, test_auc))\n",
    "paddle.save(model.state_dict(), './model/tox21/model.pdparams')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Part III: Downstream Inference\n",
    "In this part, we will briefly introduce how to use the trained downstream model to do inference on the given SMILES sequences."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load configurations\n",
    "This part is the basically the same as the part II."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[PretrainGNNModel] embed_dim:300\n",
      "[PretrainGNNModel] dropout_rate:0.5\n",
      "[PretrainGNNModel] norm_type:batch_norm\n",
      "[PretrainGNNModel] graph_norm:False\n",
      "[PretrainGNNModel] residual:False\n",
      "[PretrainGNNModel] layer_num:5\n",
      "[PretrainGNNModel] gnn_type:gin\n",
      "[PretrainGNNModel] JK:last\n",
      "[PretrainGNNModel] readout:mean\n",
      "[PretrainGNNModel] atom_names:['atomic_num', 'chiral_tag']\n",
      "[PretrainGNNModel] bond_names:['bond_dir', 'bond_type']\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/bio/.local/lib/python3.7/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n",
      "  and should_run_async(code)\n"
     ]
    }
   ],
   "source": [
    "compound_encoder_config = load_json_config(\"model_configs/pregnn_paper.json\")\n",
    "model_config = load_json_config(\"model_configs/down_linear.json\")\n",
    "model_config['num_tasks'] = len(task_names)\n",
    "\n",
    "compound_encoder = PretrainGNNModel(compound_encoder_config)\n",
    "model = DownstreamModel(model_config, compound_encoder)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load finetuned models\n",
    "Load the finetuned model from part II."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.set_state_dict(paddle.load('./model/tox21/model.pdparams'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Start Inference\n",
    "Do inference on the given SMILES sequence. We use directly call `DownstreamTransformFn` and `DownstreamCollateFn` to convert the raw SMILES sequence to the model input. \n",
    "\n",
    "Using Tox21 dataset as an example, our finetuned downstream model can make predictions over 12 sub-tasks on Tox21."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SMILES:O=C1c2ccccc2C(=O)C1c1ccc2cc(S(=O)(=O)[O-])cc(S(=O)(=O)[O-])c2n1\n",
      "Predictions:\n",
      "  NR-AR:\t0.40785\n",
      "  NR-AR-LBD:\t0.3486675\n",
      "  NR-AhR:\t0.36817572\n",
      "  NR-Aromatase:\t0.40480733\n",
      "  NR-ER:\t0.40827876\n",
      "  NR-ER-LBD:\t0.34133768\n",
      "  NR-PPAR-gamma:\t0.3827442\n",
      "  SR-ARE:\t0.37666\n",
      "  SR-ATAD5:\t0.3464901\n",
      "  SR-HSE:\t0.37402567\n",
      "  SR-MMP:\t0.47734728\n",
      "  SR-p53:\t0.39040926\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/bio/tools/paddle2.0/lib/python3.7/site-packages/paddle/nn/layer/norm.py:648: UserWarning: When training, we now always track global mean and variance.\n",
      "  \"When training, we now always track global mean and variance.\")\n"
     ]
    }
   ],
   "source": [
    "SMILES=\"O=C1c2ccccc2C(=O)C1c1ccc2cc(S(=O)(=O)[O-])cc(S(=O)(=O)[O-])c2n1\"\n",
    "transform_fn = DownstreamTransformFn(is_inference=True)\n",
    "collate_fn = DownstreamCollateFn(\n",
    "        atom_names=compound_encoder_config['atom_names'], \n",
    "        bond_names=compound_encoder_config['bond_names'],\n",
    "        is_inference=True)\n",
    "graph = collate_fn([transform_fn({'smiles': SMILES})])\n",
    "preds = model(graph.tensor()).numpy()[0]\n",
    "print('SMILES:%s' % SMILES)\n",
    "print('Predictions:')\n",
    "for name, prob in zip(task_names, preds):\n",
    "    print(\"  %s:\\t%s\" % (name, prob))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
